import abc

import numpy as np
from metaworld.policies.policy import move


class Skill(abc.ABC):
    """Abstract base class for a motion primitive driven towards a target.

    Subclasses provide the skill name, the target position and the completion
    test. This base class adds a warm-up blend of the target and a
    proportional-integral correction whose result is passed to ``move`` to
    obtain the per-step position delta.
    """

    def __init__(self, target_value, atol=0.01, p=25, grab_effort=0, warmup_steps=25):
        """Store the goal, the controller gains and the warm-up state.

        Args:
            target_value: Goal handed to ``get_target_pos`` on every step.
            atol: Absolute tolerance kept for the subclasses' ``is_done`` checks.
            p: Value forwarded to ``move`` as its ``p`` keyword.
            grab_effort: Value reported by ``get_grab_effort``.
            warmup_steps: Number of steps over which the target is blended in.
        """
        # Goal description and motion parameters.
        self.target_value = target_value
        self.atol = atol
        self.p = p
        self.grab_effort = grab_effort

        # Warm-up bookkeeping; the step counter starts at 1.
        self.n_warmup_steps = warmup_steps
        self.curr_warmup_step = 1

        # Controller gains and time step. Kd is stored but the derivative term
        # is not applied in _calculate_pid.
        self.Kp, self.Ki, self.Kd = 0.65, 0.01, 0.05
        self.dt = 0.01

        # Controller state, one entry per coordinate.
        self.integral = np.zeros(3)
        self.previous_error = np.zeros(3)

    @abc.abstractmethod
    def get_name(self, curr_pos):
        """Return the name of this skill for the given current position."""

    @abc.abstractmethod
    def get_target_pos(self, curr_pos, target_value):
        """Return the position this skill should move towards."""

    @abc.abstractmethod
    def is_done(self, curr_pos):
        """Return whether the skill has finished at ``curr_pos``."""

    def smooth_target_pos(self, curr_pos, target_pos):
        """Blend the target towards ``curr_pos`` while warming up.

        While the warm-up counter is below ``n_warmup_steps`` the offset from
        ``curr_pos`` to ``target_pos`` is scaled by
        ``(curr_warmup_step / n_warmup_steps) ** 2`` and the counter is
        advanced by one. Afterwards ``target_pos`` is returned unchanged.
        """
        warming_up = self.curr_warmup_step < self.n_warmup_steps
        if not warming_up:
            return target_pos

        progress = self.curr_warmup_step / self.n_warmup_steps
        blended = curr_pos + (target_pos - curr_pos) * np.square(progress)
        self.curr_warmup_step += 1
        return blended

    def get_delta_pos(self, curr_pos):
        """Return the result of ``move`` towards the smoothed, corrected target."""
        goal = self.smooth_target_pos(
            curr_pos, self.get_target_pos(curr_pos, self.target_value)
        )
        corrected = self._calculate_pid(curr_pos, goal)
        return move(curr_pos, corrected, p=self.p)

    def get_grab_effort(self):
        """Return the grab effort configured for this skill."""
        return self.grab_effort

    def _calculate_pid(self, curr_pos, target_pos):
        """Return ``curr_pos`` shifted by the proportional and integral terms.

        Updates ``self.integral`` in place and records the latest error in
        ``self.previous_error``. No derivative term is added.
        """
        error = target_pos - curr_pos
        proportional = self.Kp * error

        # Accumulate the error over time before computing the integral term.
        self.integral += error * self.dt
        integral_term = self.Ki * self.integral

        self.previous_error = error
        return curr_pos + proportional + integral_term


class MoveXSkill(Skill):
    """Skill that drives coordinate 0 of the position to ``target_value``."""

    def __init__(self, target_value, atol=0.01, p=8.0, grab_effort=0, warmup_steps=25):
        """Forward all settings to ``Skill``; only the default ``p`` differs."""
        super().__init__(
            target_value=target_value,
            atol=atol,
            p=p,
            grab_effort=grab_effort,
            warmup_steps=warmup_steps,
        )

    def get_name(self, curr_pos):
        """Return ``"move_left"`` if coordinate 0 is below the target, else ``"move_right"``."""
        if curr_pos[0] < self.target_value:
            direction = "left"
        else:
            direction = "right"
        return "move_" + direction

    def get_target_pos(self, curr_pos, target_value):
        """Return a copy of ``curr_pos`` with coordinate 0 set to ``target_value``."""
        goal = np.copy(curr_pos)
        goal[0] = target_value
        return goal

    def is_done(self, curr_pos):
        """Return whether coordinate 0 is within ``atol`` of the target (``np.isclose``)."""
        return np.isclose(curr_pos[0], self.target_value, atol=self.atol)


class MoveYSkill(Skill):
    """Skill that drives coordinate 1 of the position to ``target_value``."""

    def __init__(self, target_value, atol=0.01, p=8.0, grab_effort=0, warmup_steps=25):
        """Forward all settings to ``Skill``; only the default ``p`` differs."""
        super().__init__(
            target_value=target_value,
            atol=atol,
            p=p,
            grab_effort=grab_effort,
            warmup_steps=warmup_steps,
        )

    def get_name(self, curr_pos):
        """Return ``"move_front"`` if coordinate 1 is below the target, else ``"move_back"``."""
        if curr_pos[1] < self.target_value:
            direction = "front"
        else:
            direction = "back"
        return "move_" + direction

    def get_target_pos(self, curr_pos, target_value):
        """Return a copy of ``curr_pos`` with coordinate 1 set to ``target_value``."""
        goal = np.copy(curr_pos)
        goal[1] = target_value
        return goal

    def is_done(self, curr_pos):
        """Return whether coordinate 1 is within ``atol`` of the target (``np.isclose``)."""
        return np.isclose(curr_pos[1], self.target_value, atol=self.atol)


class MoveZSkill(Skill):
    """Skill that drives coordinate 2 of the position to ``target_value``."""

    def __init__(self, target_value, atol=0.01, p=8.0, grab_effort=0, warmup_steps=25):
        """Forward all settings to ``Skill``; only the default ``p`` differs."""
        super().__init__(
            target_value=target_value,
            atol=atol,
            p=p,
            grab_effort=grab_effort,
            warmup_steps=warmup_steps,
        )

    def get_name(self, curr_pos):
        """Return ``"move_up"`` if coordinate 2 is below the target, else ``"move_down"``."""
        if curr_pos[2] < self.target_value:
            direction = "up"
        else:
            direction = "down"
        return "move_" + direction

    def get_target_pos(self, curr_pos, target_value):
        """Return a copy of ``curr_pos`` with coordinate 2 set to ``target_value``."""
        goal = np.copy(curr_pos)
        goal[2] = target_value
        return goal

    def is_done(self, curr_pos):
        """Return whether coordinate 2 is within ``atol`` of the target (``np.isclose``)."""
        return np.isclose(curr_pos[2], self.target_value, atol=self.atol)


class PushSkill(Skill):
    """Skill that advances towards a full target position in fixed-length steps."""

    def __init__(self, target_value, atol=0.01, p=8.0, grab_effort=0, warmup_steps=25):
        """Forward all settings to ``Skill``; only the default ``p`` differs."""
        super().__init__(target_value, atol, p, grab_effort, warmup_steps)

    def get_name(self, curr_pos):
        """Return the constant skill name ``"push"``."""
        return "push"

    def get_delta_pos(self, curr_pos):
        """Return the ``move`` result for a 0.2-long step towards the target.

        The offset from ``curr_pos`` to the (smoothed) target is normalized and
        rescaled to length 0.2 before being passed to ``move``. The result is
        also stored on ``self.delta_pos``.

        NOTE(review): because the offset is re-normalized, the scaling applied
        by ``smooth_target_pos`` does not change the step length here; it only
        advances the warm-up counter. Confirm this is intended.
        """
        target_pos = self.get_target_pos(curr_pos, self.target_value)
        target_pos = self.smooth_target_pos(curr_pos, target_pos)

        offset = target_pos - curr_pos
        magnitude = np.linalg.norm(offset)
        if magnitude < 1e-12:
            # Already at the target: normalizing a zero-length offset would
            # divide by zero and yield NaN, so request no displacement instead.
            target_pos = curr_pos
        else:
            target_pos = curr_pos + offset / magnitude * 0.2

        self.delta_pos = move(curr_pos, target_pos, p=self.p)
        return self.delta_pos

    def get_target_pos(self, curr_pos, target_value):
        """Return ``target_value`` unchanged; ``curr_pos`` is not used."""
        return target_value

    def is_done(self, curr_pos):
        """Return whether every coordinate is within ``atol`` of the target."""
        return np.all(np.isclose(curr_pos, self.target_value, atol=self.atol))


class PullSkill(Skill):
    """Skill that advances towards a full target position in fixed-length steps."""

    def __init__(self, target_value, atol=0.01, p=8.0, grab_effort=0, warmup_steps=25):
        """Forward all settings to ``Skill``; only the default ``p`` differs."""
        super().__init__(target_value, atol, p, grab_effort, warmup_steps)

    def get_name(self, curr_pos):
        """Return the constant skill name ``"pull"``."""
        return "pull"

    def get_delta_pos(self, curr_pos):
        """Return the ``move`` result for a 0.2-long step towards the target.

        The offset from ``curr_pos`` to the (smoothed) target is normalized and
        rescaled to length 0.2 before being passed to ``move``.

        NOTE(review): because the offset is re-normalized, the scaling applied
        by ``smooth_target_pos`` does not change the step length here; it only
        advances the warm-up counter. Confirm this is intended.
        """
        target_pos = self.get_target_pos(curr_pos, self.target_value)
        target_pos = self.smooth_target_pos(curr_pos, target_pos)

        offset = target_pos - curr_pos
        magnitude = np.linalg.norm(offset)
        if magnitude < 1e-12:
            # Already at the target: normalizing a zero-length offset would
            # divide by zero and yield NaN, so request no displacement instead.
            target_pos = curr_pos
        else:
            target_pos = curr_pos + offset / magnitude * 0.2

        delta_pos = move(curr_pos, target_pos, p=self.p)
        return delta_pos

    def get_target_pos(self, curr_pos, target_value):
        """Return ``target_value`` unchanged; ``curr_pos`` is not used."""
        return target_value

    def is_done(self, curr_pos):
        """Return whether every coordinate is within ``atol`` of the target."""
        return np.all(np.isclose(curr_pos, self.target_value, atol=self.atol))


class GrabSkill(Skill):
    """Skill that holds the current position while applying a grab effort."""

    def __init__(self, grab_effort=1, grab_duration=10):
        """Set up the grab with a placeholder target of 0 and a fresh counter.

        Args:
            grab_effort: Effort passed on to ``Skill`` for ``get_grab_effort``.
            grab_duration: Number of ``is_done`` calls that report False
                before the skill reports completion.
        """
        super().__init__(target_value=0, grab_effort=grab_effort)
        self.grab_counter = 0
        self.grab_duration = grab_duration

    def get_name(self, curr_pos):
        """Return the constant skill name ``"grab"``."""
        return "grab"

    def get_target_pos(self, curr_pos, target_value):
        """Return ``curr_pos`` itself, so the target equals the current position."""
        return curr_pos

    def is_done(self, curr_pos):
        """Report completion once the counter has reached ``grab_duration``.

        Each call that returns False also advances ``grab_counter`` by one, so
        the result depends on how many times this method has been called.
        """
        if not self.grab_counter >= self.grab_duration:
            self.grab_counter += 1
            return False
        return True
